'''
Module for initialising parameters and building the model
'''

import theano
import theano.tensor as tensor
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
import numpy
from collections import OrderedDict
from utils import _p, ortho_weight, norm_weight, xavier_weight, l2norm
from layers import get_layer

# Fixed multiplicative scale applied to intermediate activations in the
# expert builders (``prev * c``). The rationale for the value is not stated
# here; presumably it sharpens the downstream softmax — TODO confirm.
c = 10.

def set_defaults(options):
    """
    Fill in default values for any missing model/training options.

    ``options`` is updated in place and also returned. Existing entries are
    never overwritten. The one exception is the learning rate: it is stored
    under ``'lr'`` and taken from ``'learning_rate'`` when that key is given.

    Parameters
    ----------
    options : dict
        User-supplied configuration.

    Returns
    -------
    dict
        The same ``options`` object, with defaults applied.
    """
    options.setdefault('batch_size', None)
    # Regularisation
    options.setdefault('l1', 0.0)
    options.setdefault('l2', 0.0)
    # Optimisation method
    options.setdefault('opt', 'adam')
    options.setdefault('pooling', False)
    options.setdefault('pooling_activ', 'sum')
    # The learning rate is supplied as 'learning_rate' but consumed as 'lr'.
    # Fall back to an existing 'lr' before the hard default. Otherwise a
    # caller that set 'lr' directly is silently reset to 0.1.
    options['lr'] = options.get('learning_rate', options.get('lr', 0.1))
    options.setdefault('max_itr', 25000)
    options.setdefault('debug', False)
    options.setdefault('ar_layers', 1)
    options.setdefault('dropout', False)
    options.setdefault('dropout_rate', 0.2)
    # Hidden-layer sizes for each expert and for the gating ("manager") net.
    # Each call builds fresh lists, so defaults are never shared.
    options.setdefault('expert1_units', [0])
    options.setdefault('expert2_units', [0])
    options.setdefault('expert3_units', [0])
    options.setdefault('expert4_units', [0])
    options.setdefault('expert5_units', [0])
    options.setdefault('manager_units', [0])
    options.setdefault('shared_ld', True)
    # One activation per manager layer (must come after 'manager_units').
    options.setdefault('activ', ['linear'] * len(options['manager_units']))
    # Weighting of the validation set relative to the training set when
    # selecting the best-performing model.
    options.setdefault('percent_valid', 0.)
    options.setdefault('model_seed', None)
    # MAML
    options.setdefault('theta', 0.0)
    return options


def init_params(options, rng=None):
    """
    Initialise every network parameter and its constraints.

    Layers are initialised in a fixed order:
    1. the gating network ("manager"), followed by a 'softmax2' layer;
    2. experts 1-5, each followed by a 'softmax2' layer;
    3. the AR layers and the output layer.
    The same ``rng`` is passed to every initialiser, so this order matters
    for reproducibility.

    Returns
    -------
    (params, constraints) : two OrderedDicts, keyed by parameter name.
    """
    n_inputs = 2  # number of player utilities

    # (layer type / prefix stem, options key with hidden sizes, bias offset)
    stacks = [
        (name, [n_inputs] + options[units_key], bias)
        for name, units_key, bias in (
            ('manager', 'manager_units', 1.),
            ('expert1', 'expert1_units', 0.),
            ('expert2', 'expert2_units', 1.),
            ('expert3', 'expert3_units', 0.),
            ('expert4', 'expert4_units', 0.),
            ('expert5', 'expert5_units', 0.),
        )
    ]

    params = OrderedDict()
    constraints = OrderedDict()

    for name, sizes, bias in stacks:
        layer_pairs = zip(sizes[:-1], sizes[1:])
        for depth, (fan_in, fan_out) in enumerate(layer_pairs, start=1):
            params = get_layer(name)[0](options, params,
                                        prefix='%s%02d' % (name, depth),
                                        nin=fan_in,
                                        nout=fan_out,
                                        rng=rng, b_offset=bias)
        params = get_layer('softmax2')[0](options, params, nin=sizes[-1],
                                          rng=rng)

    n_ar = options['ar_layers']
    for level in range(n_ar):
        for player in range(2):
            # Player 2's AR layer at the top level is never used, so skip it.
            if level == n_ar - 1 and player == 1:
                continue
            params, constraints = get_layer('ar')[0](
                options, params,
                prefix='p%d_ar%d' % (player, level),
                nin=n_ar, level=level, rng=rng,
                constraints=constraints)
    params, constraints = get_layer('output')[0](options, params, constraints,
                                                 rng=rng, nin=n_ar)
    return params, constraints

def to_list(x, n):
    """Return ``x`` as-is if it is a list, else a list with ``x`` repeated ``n`` times."""
    if isinstance(x, list):
        return x
    return [x] * n

def build_features(x, tparams, options, use_noise, trng, normalise=True):
    '''
    Build the feature-layer stack (used by ``build_model``).

    For each entry of ``options['hidden_units']`` this applies, in order:
    - optional pooling (``options['pooling']``);
    - a 'hid' layer named ``hiddenNN``;
    - optional dropout (``options['dropout']``).
    The result goes through a 'sum' layer and, when ``normalise`` is true, a
    'softmax' layer.

    NOTE(review): 'hidden_units' has no default in ``set_defaults`` and
    ``init_params`` creates no 'hid' layers, so this path looks legacy —
    confirm before relying on it.

    Returns (out, hidden_outputs). ``hidden_outputs`` holds each 'hid'
    layer's output, taken before dropout.
    '''
    dropout_on = options['dropout']
    if dropout_on:
        print('Using dropout')

    layer_count = len(options['hidden_units'])
    activations = to_list(options['activ'], layer_count)

    hidden_outputs = []
    h = x
    for idx in range(layer_count):
        if options['pooling']:
            # Add pooling units
            h = get_layer('pooling')[1](tparams, h, options,
                                        activ=options['pooling_activ'])
        h = get_layer('hid')[1](tparams, h, options,
                                prefix='hidden%02d' % (idx + 1),
                                activ=activations[idx])
        hidden_outputs.append(h)
        if dropout_on:
            h = get_layer('dropout')[1](h, use_noise, options, trng)

    summed = get_layer('sum')[1](tparams, h, options)
    if not normalise:
        return summed, hidden_outputs
    return get_layer('softmax')[1](tparams, summed, options), hidden_outputs

# Expert 1
def build_expert1(x, tparams, options, use_noise, trng, normalise=True):
    '''
    Build expert 1: a stack of 'expert1' layers, with an 'r'-type pooling
    step before the second layer. The pooled result is scaled by the module
    constant ``c``.

    ``use_noise``, ``trng`` and ``normalise`` are accepted to match the other
    expert builders' signature and are not used here.

    Returns (out, hidden_outputs). ``out`` is the 'softmax2' layer applied to
    the last expert layer; ``hidden_outputs`` lists each 'expert1' layer's
    output.
    '''
    depth = len(options['expert1_units'])
    activations = to_list(options['activ'], depth)

    layer_outputs = []
    h = x
    for layer in range(depth):
        if layer == 1:
            pooled = get_layer('pooling')[1](tparams, h, options,
                                             activ=options['pooling_activ'],
                                             type='r')
            h = pooled * c
        h = get_layer('expert1')[1](tparams, h, options,
                                    prefix='expert1%02d' % (layer + 1),
                                    activ=activations[layer])
        layer_outputs.append(h)
    return get_layer('softmax2')[1](tparams, h, options), layer_outputs

# Expert 2
def build_expert2(x, tparams, options, use_noise, trng, normalise=True):
    '''
    Build expert 2: a stack of 'expert2' layers, with an 'r'-type pooling
    step before the second layer. The pooled result is scaled by the module
    constant ``c``.

    ``use_noise``, ``trng`` and ``normalise`` are accepted to match the other
    expert builders' signature and are not used here.

    Returns (out, hidden_outputs). ``out`` is the 'softmax2' layer applied to
    the last expert layer; ``hidden_outputs`` lists each 'expert2' layer's
    output.
    '''
    n_layers = len(options['expert2_units'])
    activs = to_list(options['activ'], n_layers)

    outputs = []
    current = x
    for k in range(n_layers):
        if k == 1:
            current = get_layer('pooling')[1](tparams, current, options,
                                              activ=options['pooling_activ'],
                                              type='r')
            current = current * c
        current = get_layer('expert2')[1](tparams, current, options,
                                          prefix='expert2%02d' % (k + 1),
                                          activ=activs[k])
        outputs.append(current)

    final = get_layer('softmax2')[1](tparams, current, options)
    return final, outputs

# Expert 3
def build_expert3(x, tparams, options, use_noise, trng, normalise=True):
    '''
    Build expert 3: a stack of 'expert3' layers, each output scaled by the
    module constant ``c``.

    When ``options['pooling']`` is set, each layer's input ``h`` is first
    replaced by relu(h + opp) + relu(-opp - h), which equals |h + opp|.
    Here ``opp`` is ``x`` with its two axis-1 entries swapped.

    NOTE(review): the pooling branch runs before every layer, not only the
    first. From layer 2 on it therefore adds ``opp`` (shaped like ``x``) to
    hidden activations. ``build_expert5`` does the analogous step only at
    layer 0 — confirm this is intended.

    ``use_noise``, ``trng`` and ``normalise`` are not used here.

    Returns (out, hidden_outputs). ``out`` is the 'softmax2' layer applied to
    the last scaled output; ``hidden_outputs`` lists each scaled layer
    output.
    '''
    # Transpose, swap the two axis-1 entries, transpose back. The net effect
    # is x with its two axis-1 entries swapped.
    swapped = x.transpose((0, 1, 3, 2))[:, [1, 0], :, :]
    opp = swapped.transpose((0, 1, 3, 2))

    depth = len(options['expert3_units'])
    activations = to_list(options['activ'], depth)

    layer_outputs = []
    h = x
    for layer in range(depth):
        if options['pooling']:
            h = tensor.nnet.relu(h + opp) + tensor.nnet.relu(-opp - h)
        h = get_layer('expert3')[1](tparams, h, options,
                                    prefix='expert3%02d' % (layer + 1),
                                    activ=activations[layer])
        h = h * c
        layer_outputs.append(h)

    return get_layer('softmax2')[1](tparams, h, options), layer_outputs

# Expert 4
def build_expert4(x, tparams, options, use_noise, trng, normalise=True):
    '''
    Build expert 4: a stack of 'expert4' layers with two pooling steps.

    - Before layer 1, the input is replaced by (pooling with type='c') minus
      the input.
    - Before layer 2, 'r'-type pooling is applied and the result is scaled
      by the module constant ``c``.

    ``use_noise``, ``trng`` and ``normalise`` are not used here.

    Returns (out, hidden_outputs). ``out`` is the 'softmax2' layer applied to
    the last expert layer; ``hidden_outputs`` lists each 'expert4' layer's
    output.
    '''
    depth = len(options['expert4_units'])
    activations = to_list(options['activ'], depth)

    layer_outputs = []
    h = x
    for layer in range(depth):
        if layer == 0:
            pooled_c = get_layer('pooling')[1](tparams, h, options,
                                               activ=options['pooling_activ'],
                                               type='c')
            h = pooled_c - h
        elif layer == 1:
            pooled_r = get_layer('pooling')[1](tparams, h, options,
                                               activ=options['pooling_activ'],
                                               type='r')
            h = pooled_r * c
        h = get_layer('expert4')[1](tparams, h, options,
                                    prefix='expert4%02d' % (layer + 1),
                                    activ=activations[layer])
        layer_outputs.append(h)
    return get_layer('softmax2')[1](tparams, h, options), layer_outputs

# Expert 5
def build_expert5(x, tparams, options, use_noise, trng, normalise=True):
    '''
    Build expert 5: a stack of 'expert5' layers, each output scaled by the
    module constant ``c``.

    Before the first layer, the input is replaced by
    relu(x - opp) + relu(opp - x), which equals |x - opp|. Here ``opp`` is
    ``x`` with its two axis-1 entries swapped.

    ``use_noise``, ``trng`` and ``normalise`` are not used here.

    Returns (out, hidden_outputs). ``out`` is the 'softmax2' layer applied to
    the last scaled output; ``hidden_outputs`` lists each scaled layer
    output.
    '''
    # Transpose, swap the two axis-1 entries, transpose back. The net effect
    # is x with its two axis-1 entries swapped.
    swapped = x.transpose((0, 1, 3, 2))[:, [1, 0], :, :]
    opp = swapped.transpose((0, 1, 3, 2))

    depth = len(options['expert5_units'])
    activations = to_list(options['activ'], depth)

    layer_outputs = []
    h = x
    for layer in range(depth):
        if layer == 0:
            h = tensor.nnet.relu(h - opp) + tensor.nnet.relu(opp - h)
        h = get_layer('expert5')[1](tparams, h, options,
                                    prefix='expert5%02d' % (layer + 1),
                                    activ=activations[layer])
        h = h * c
        layer_outputs.append(h)
    return get_layer('softmax2')[1](tparams, h, options), layer_outputs

def build_gating(x, tparams, options):
    '''
    Build the gating ("manager") network.

    Applies one 'manager' layer per entry of ``options['manager_units']``.
    The last output then goes through a 'mean' layer and a 'softmax' layer.

    Returns (out, hidden_outputs). ``out`` is the softmax output;
    ``hidden_outputs`` lists each 'manager' layer's output.
    '''
    depth = len(options['manager_units'])
    activations = to_list(options['activ'], depth)

    hidden_outputs = []
    h = x
    for layer in range(depth):
        h = get_layer('manager')[1](tparams, h, options,
                                    prefix='manager%02d' % (layer + 1),
                                    activ=activations[layer])
        hidden_outputs.append(h)

    averaged = get_layer('mean')[1](tparams, h, options)
    return get_layer('softmax')[1](tparams, averaged, options), hidden_outputs

def build_ar_layers(x, tparams, options, features, hiddens):
    '''
    Build the stacked 'ar' layers for both players.

    Parameters
    ----------
    x : symbolic 4-d tensor
        Input batch. ``x[:, 0]`` is used as player 1's utilities. ``x[:, 1]``,
        with its last two axes swapped, is used as player 2's (presumably so
        both are indexed by the player's own actions first — confirm).
    tparams : parameters passed through to the 'ar' layer.
    options : dict; ``options['ar_layers']`` gives the number of levels.
    features : pair (player-1, player-2) passed to the 'ar' layer as its
        main input.
    hiddens : pair of 4-d tensors. Each is concatenated along axis 1 after
        the matching utility matrix (reshaped to have a singleton axis 1).
        The result is the 'ar' layer's ``payoff``.

    Returns
    -------
    ar_lists : pair of lists. ``ar_lists[p][i]`` concatenates player p's
        'ar' outputs for levels 0..i along axis 1; each output is reshaped
        to (n, 1, d) first. Player 2 has no entry for the final level.
    weighted_feature_list : pair of lists of the 'ar' layer's second output,
        one per level built.
    br_list : pair of lists of the 'ar' layer's third output, collected only
        for levels i >= 1.
    '''
    u1, u2 = (x[:, 0, :, :], x[:, 1, :, :].transpose(0, 2, 1))
    h1, h2 = hiddens
    # concatenate the payoff matrix onto the final layer hidden units
    # (u reshaped to (batch, 1, rows, cols) so it becomes one extra channel)
    utility = (tensor.concatenate((u1.reshape((u1.shape[0],
                                                1,
                                                u1.shape[1],
                                                u1.shape[2])), h1), axis=1),
                tensor.concatenate((u2.reshape((u2.shape[0],
                                                1,
                                                u2.shape[1],
                                                u2.shape[2])), h2), axis=1))

    ar_layers = options['ar_layers']

    ar_lists = ([], [])
    # Level 0 has no opposition. After level i each player's opposition is
    # the other player's accumulated AR outputs (set at the end of the loop).
    opp = [None, None]
    weighted_feature_list = ([], [])
    br_list = ([], [])
    for i in range(ar_layers):
        for p in range(2):
            if i == (ar_layers - 1) and p == 1:
                continue  # player 2's top-level ar layer is never used, so skip it
            feat = features[p]
            ar, weighted_features, br = get_layer('ar')[1](tparams,
                                                           feat,
                                                           options,
                                                           payoff=utility[p],
                                                           prefix='p%d_ar%d' % (p, i),
                                                           opposition=opp[p],
                                                           level=i)
            n, d = ar.shape
            ar = ar.reshape((n, 1, d))  # make space to concat ar layers
            weighted_feature_list[p].append(weighted_features)
            if i == 0:
                ar_lists[p].append(ar)
            else:
                # Running stack: levels 0..i along axis 1.
                ar_lists[p].append(tensor.concatenate((ar_lists[p][i - 1], ar),
                                   axis=1))
                # NOTE(review): br is only collected for levels >= 1 (it sits
                # in this else-branch); confirm dropping level 0's br is intended.
                br_list[p].append(br)

        # append each layer then update the opposition variable...
        if i < ar_layers - 1:
            for p in range(2):
                opp[1 - p] = ar_lists[p][i]
    return ar_lists, weighted_feature_list, br_list

def build_model(tparams, options, rng=None):
    """
    Build the (non-mixture-of-experts) computation graph.

    Feature layers are built for player 1's view and player 2's view. The
    AR stack runs on top of them, and the output layer is applied to player
    1's top-level AR output.

    ``rng`` seeds the symbolic random stream; it defaults to
    ``numpy.random.RandomState(123)``.

    Returns (trng, use_noise, x, out). When ``options['debug']`` is truthy, a
    dict of intermediate expressions is appended as a fifth element.
    """
    rng = numpy.random.RandomState(123) if rng is None else rng
    trng = RandomStreams(rng.randint(1000000))
    use_noise = theano.shared(numpy.float32(0.))
    x = tensor.tensor4('x')
    # transpose to get player 2 model
    x_p2 = x.transpose((0, 1, 3, 2))[:, [1, 0], :, :]

    own_features, hidden1 = build_features(x, tparams, options, use_noise, trng)
    opp_features, hidden2 = build_features(x_p2, tparams, options, use_noise, trng)

    ar, weighted_feature_list, br_list = build_ar_layers(
        x, tparams, options,
        (own_features, opp_features),
        (hidden1[-1], hidden2[-1]))
    top_level = options['ar_layers'] - 1
    out = get_layer('output')[1](tparams, ar[0][top_level], options)

    if not options['debug']:
        return trng, use_noise, x, out

    intermediate_fns = dict(ar=ar,
                            own_features=own_features,
                            opp_features=opp_features,
                            hidden1=hidden1,
                            hidden2=hidden2,
                            weighted_feature_list=weighted_feature_list,
                            br_list=br_list)
    return trng, use_noise, x, out, intermediate_fns

def _build_moe_branch(inputs, tparams, options, use_noise, trng):
    """
    Build the five experts and the gating network for one player's view.

    Returns (gating_probs, gating_hiddens, experts):
      gating_probs   -- first return value of ``build_gating`` (softmax output)
      gating_hiddens -- second return value of ``build_gating`` (list of
                        per-layer manager outputs)
      experts        -- the experts' 'softmax2' outputs stacked on axis 1,
                        reshaped to (shape[3] of expert 5's output, -1)
    """
    builders = (build_expert1, build_expert2, build_expert3,
                build_expert4, build_expert5)
    expert_outs = [build(inputs, tparams, options, use_noise, trng)[0]
                   for build in builders]
    stacked = tensor.stack(expert_outs, axis=1)
    experts = tensor.reshape(stacked, (expert_outs[-1].shape[3], -1))
    gating_probs, gating_hiddens = build_gating(inputs, tparams, options)
    return gating_probs, gating_hiddens, experts


def _moe_mixture(gating_hiddens, experts):
    """Mix the gating output with the experts and reshape for the AR layers."""
    mix = tensor.dot(gating_hiddens, experts)
    return tensor.reshape(mix, (mix.shape[1], -1, mix.shape[3], mix.shape[0]))


def build_moe_model(x, tparams, options, rng=None):
    """
    Computation graph for the mixture-of-experts model.

    Steps:
    1. Build experts and gating for player 1's view (``x``) and for player
       2's view (``x`` with each matrix transposed and the two axis-1
       entries swapped).
    2. Mix the experts with the gating output.
    3. Feed both players into the AR stack.
    4. Apply the output layer to player 1's top-level AR output.

    Parameters
    ----------
    x : symbolic 4-d tensor input.
    tparams : model parameters, passed to every layer.
    options : dict of model options (see ``set_defaults``).
    rng : numpy RandomState used to seed the symbolic random stream;
        defaults to ``numpy.random.RandomState(123)``.

    Returns
    -------
    (trng, use_noise, out). When ``options['debug']`` is truthy, a dict of
    intermediate expressions is appended as a fourth element.
    """
    if rng is None:
        rng = numpy.random.RandomState(123)
    trng = RandomStreams(rng.randint(1000000))
    use_noise = theano.shared(numpy.float32(0.))

    # Player 2's view of the game.
    x_opp = x.transpose((0, 1, 3, 2))[:, [1, 0], :, :]

    own_features, own_gating_hiddens, own_experts = _build_moe_branch(
        x, tparams, options, use_noise, trng)
    opp_features, opp_gating_hiddens, opp_experts = _build_moe_branch(
        x_opp, tparams, options, use_noise, trng)

    # Gating * Experts.
    # NOTE(review): the mix uses build_gating's *second* return value (the
    # per-layer manager outputs). The softmax output becomes the AR
    # "features" instead — confirm this pairing is intended.
    own_mixture = _moe_mixture(own_gating_hiddens, own_experts)
    opp_mixture = _moe_mixture(opp_gating_hiddens, opp_experts)

    # AR
    ar, weighted_feature_list, br_list = build_ar_layers(
        x, tparams, options,
        (own_features, opp_features),
        (own_mixture, opp_mixture))

    # ar[0][i] stacks player 1's AR outputs for levels 0..i. The output layer
    # is initialised with nin=options['ar_layers'] (see init_params), so feed
    # it the top level, not level 0. The two are identical when
    # ar_layers == 1.
    ar_layers = options['ar_layers']
    out = get_layer('output')[1](tparams, ar[0][ar_layers - 1], options)

    if not options['debug']:
        return trng, use_noise, out

    intermediate_fns = {'ar': ar,
                        'own_features': own_features,
                        'opp_features': opp_features,
                        'hidden1': own_mixture,
                        'hidden2': opp_mixture,
                        'weighted_feature_list': weighted_feature_list,
                        'br_list': br_list}
    return trng, use_noise, out, intermediate_fns


